Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.vmap: convert mapped input arguments to array #25835

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jan 10, 2025

Fixes #25745

I'm not sure whether this is the right approach – I'm going to run our test suite to see if there are unexpected issues.

@jakevdp jakevdp self-assigned this Jan 10, 2025
@jakevdp jakevdp requested a review from mattjj January 10, 2025 17:16
@jakevdp
Copy link
Collaborator Author

jakevdp commented Jan 15, 2025

Notes from convo with @mattjj – this is an OK fix to the vmap issue, but doesn't really get to the root of the problem (e.g. the same thing occurs with jvp of an identity function.

That said, the result of transform(lambda x: x)(np_array) is a corner case that may not be all that important in practice. However, the issue of transform(jnp.asarray)(np_array) returning a non-JAX array is more surprising. There are a couple ways we could address this:

  1. convert numpy arrays on input (as in this PR)
  2. convert numpy arrays on output
  3. ensure that jnp.asarray lowers to some primitive – probably something similar to the existing copy_p.

@mattjj things (1) or (2) are fine, with a slight preferance toward (2). (3) may be better, but it's hard to predict whether it may have downsides we're not anticipating. I'm going to explore (3) a bit before moving forward here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

vmap(jnp.asarray)(numpy_array) does not return a JAX array
1 participant